package com.github.hgkmail.hello.leetcode101.pointer.tree;

import com.github.hgkmail.hello.leetcode101.base.CommonUtil;
import com.github.hgkmail.hello.leetcode101.base.ListNode;
import com.github.hgkmail.hello.leetcode101.base.TreeNode;

import java.util.*;

/**
 * LeetCode 1110 "Delete Nodes And Return Forest".
 *
 * <p>Every node whose value is listed for deletion is removed from the tree. The surviving
 * children of a removed node become roots of separate trees, and all surviving trees are
 * returned as a forest.
 */
public class LC1110DeleteNodesAndReturnForest {

    /**
     * Bottom-up pass that prunes deleted nodes and collects the roots of newly detached subtrees.
     *
     * <p>Both children are processed before the current node decides whether it survives. Each
     * recursive call returns the surviving root of that child's subtree, or {@code null} if the
     * child itself was deleted. The returned value is stored back into the child link, so links to
     * deleted children are cut.
     *
     * <p>If {@code root} is deleted, each surviving child is appended to {@code res} as a new tree
     * root. Left comes before right, and each child is appended right after any roots produced
     * inside its own subtree.
     *
     * @param root    subtree to process; may be {@code null}
     * @param deletes values of the nodes to delete
     * @param res     output list receiving the roots of subtrees detached by a deletion
     * @return {@code root} (with updated child links) if it survives, or {@code null} if
     *     {@code root} is {@code null} or deleted
     */
    public TreeNode dfs(TreeNode root, Set<Integer> deletes, List<TreeNode> res) {
        if (root == null) {
            return null;
        }
        boolean deleted = deletes.contains(root.val);

        // A surviving child of a deleted node starts a new tree in the forest.
        root.left = dfs(root.left, deletes, res);
        if (deleted && root.left != null) {
            res.add(root.left);
        }
        root.right = dfs(root.right, deletes, res);
        if (deleted && root.right != null) {
            res.add(root.right);
        }
        return deleted ? null : root;
    }

    /**
     * Deletes the given values from the tree and returns the roots of the remaining forest.
     *
     * <p>Two cases arise. Deleting a leaf simply removes it. Deleting an internal node detaches its
     * left and right subtrees from the original tree, and each surviving subtree becomes a new
     * root. The input tree is modified in place: child links pointing to deleted nodes are cut.
     *
     * @param root      root of the tree; may be {@code null}
     * @param to_delete values to delete; {@code null} or empty means nothing is deleted
     * @return roots of the resulting forest. Trees exposed by deletions come first; the original
     *     root comes last if it survives.
     */
    public List<TreeNode> delNodes(TreeNode root, int[] to_delete) {
        List<TreeNode> res = new ArrayList<>();
        Set<Integer> deletes = new HashSet<>();
        if (to_delete != null) {
            for (int value : to_delete) {
                deletes.add(value);
            }
        }
        // dfs returns the original root exactly when it is non-null and not deleted.
        TreeNode survivingRoot = dfs(root, deletes, res);
        if (survivingRoot != null) {
            res.add(survivingRoot);
        }
        return res;
    }

    /** Manual check: deletes nodes 2 and 1 from a small sample tree and prints each remaining tree. */
    public static void main(String[] args) {
        TreeNode root = CommonUtil.deserializeBinaryTree("1,2,#,#,3,#,4");
        List<TreeNode> roots = new LC1110DeleteNodesAndReturnForest().delNodes(root, new int[]{2, 1});
        for (TreeNode r : roots) {
            System.out.println(CommonUtil.serializeBinaryTree(r));
        }
    }
}
